import sys
import pickle
import pandas as pd
from sklearn import metrics
sys.path.append(sys.argv[3])
from signal_udfs import unpickle
import numpy as np
import statistics

#Positional command-line arguments used by this script:
#  argv[1] -> input_handle (not referenced again in the visible code)
#  argv[2] -> output_handle, from which the output .csv path is derived
#  argv[3] -> directory appended to sys.path so signal_udfs can be imported
#  argv[4] -> pickled trace dictionary
#  argv[5] -> pickled peak dictionary
input_handle, output_handle = sys.argv[1], sys.argv[2]

#Load both dictionaries. unpickle() is handed a one-element list of paths and
#the first element of its returned sequence is the loaded object.
#NOTE(review): unpickle presumably wraps pickle.load, which can execute code
#embedded in the file -- only pass trusted pickle files.
trace_dict = unpickle([sys.argv[4]])[0]
peak_dict = unpickle([sys.argv[5]])[0]

#Column names for the per-event quantification table, in output order:
#'Name', then 17 measures for the baseline window ('Event Baseline ...'),
#then the same 17 measures for the event window ('Event ...').
event_names = [
	'Name',
	'Event Baseline Max Z-Score',
	'Event Baseline Min Z-Score',
	'Event Baseline Mean Z-Score',
	'Event Baseline Z-Score AUC (+)',
	'Event Baseline Z-Score AUC (-)',
	'Event Baseline Z-Score AUC/s (+)',
	'Event Baseline Z-Score AUC/s (-)',
	'Event Baseline Max DF/F0',
	'Event Baseline Min DF/F0',
	'Event Baseline Mean DF/F0',
	'Event Baseline DF/F0 AUC (+)',
	'Event Baseline DF/F0 AUC (-)',
	'Event Baseline DF/F0 AUC/s (+)',
	'Event Baseline DF/F0 AUC/s (-)',
	'Event Baseline Transients',
	'Event Baseline Transients/s',
	'Event Baseline Transients Mean Amplitude',
	'Event Max Z-Score',
	'Event Min Z-Score',
	'Event Mean Z-Score',
	'Event Z-Score AUC (+)',
	'Event Z-Score AUC (-)',
	'Event Z-Score AUC/s (+)',
	'Event Z-Score AUC/s (-)',
	'Event Max DF/F0',
	'Event Min DF/F0',
	'Event Mean DF/F0',
	'Event DF/F0 AUC (+)',
	'Event DF/F0 AUC (-)',
	'Event DF/F0 AUC/s (+)',
	'Event DF/F0 AUC/s (-)',
	'Event Transients',
	'Event Transients/s',
	'Event Transients Mean Amplitude',
]

#One independent empty list per column (filled by the sections below); the
#'Name' column is then bound to trace_dict's own 'Name' list (not a copy).
event_dict = {}
for column in event_names:
	event_dict[column] = []
event_dict['Name'] = trace_dict['Name']


#Calculate max, min, and mean per trace. pd.DataFrame places each trace in
#trace_dict[key] on its own row (assumes a list of per-event sequences; rows
#of unequal length are NaN-padded by pandas), and the row-wise reductions skip
#NaN. Each source key replaces three event_dict columns with plain lists.
stat_targets = {
	# source key: (max column, min column, mean column)
	'Baseline Trace': ('Event Baseline Max DF/F0', 'Event Baseline Min DF/F0', 'Event Baseline Mean DF/F0'),
	'Event Baseline Z-Score': ('Event Baseline Max Z-Score', 'Event Baseline Min Z-Score', 'Event Baseline Mean Z-Score'),
	'Event Trace': ('Event Max DF/F0', 'Event Min DF/F0', 'Event Mean DF/F0'),
	'Event Z-Score': ('Event Max Z-Score', 'Event Min Z-Score', 'Event Mean Z-Score'),
}

for source_key, (max_col, min_col, mean_col) in stat_targets.items():
	frame = pd.DataFrame(trace_dict[source_key])
	row_means = frame.mean(axis=1, skipna=True).to_list()
	row_maxes = frame.max(axis=1, skipna=True).to_list()
	row_mins = frame.min(axis=1, skipna=True).to_list()
	event_dict[max_col] = row_maxes
	event_dict[min_col] = row_mins
	event_dict[mean_col] = row_means


#Calculate AUC. Each trace is split into a positive part (samples > 0 kept,
#everything else replaced by 0) and a negative part (samples <= 0 kept,
#everything else replaced by 0). metrics.auc integrates each part against the
#matching time vector, and the "/s" columns divide that area by the time span
#(last minus first timestamp). Traces with fewer than two samples get the
#'NaN' string sentinel in all four columns.
auc_targets = {
	# trace key: (time key, AUC (+), AUC/s (+), AUC (-), AUC/s (-))
	'Baseline Trace': ('Baseline Time', 'Event Baseline DF/F0 AUC (+)', 'Event Baseline DF/F0 AUC/s (+)', 'Event Baseline DF/F0 AUC (-)', 'Event Baseline DF/F0 AUC/s (-)'),
	'Event Baseline Z-Score': ('Baseline Time', 'Event Baseline Z-Score AUC (+)', 'Event Baseline Z-Score AUC/s (+)', 'Event Baseline Z-Score AUC (-)', 'Event Baseline Z-Score AUC/s (-)'),
	'Event Trace': ('Event Time', 'Event DF/F0 AUC (+)', 'Event DF/F0 AUC/s (+)', 'Event DF/F0 AUC (-)', 'Event DF/F0 AUC/s (-)'),
	'Event Z-Score': ('Event Time', 'Event Z-Score AUC (+)', 'Event Z-Score AUC/s (+)', 'Event Z-Score AUC (-)', 'Event Z-Score AUC/s (-)'),
}

for trace_key, (time_key, pos_col, pos_rate_col, neg_col, neg_rate_col) in auc_targets.items():
	traces = trace_dict[trace_key]
	times = trace_dict[time_key]

	# Split every trace up front; None marks a trace too short to integrate.
	split_traces = []
	for trace in traces:
		if len(trace) > 1:
			positive = [sample if sample > 0 else 0 for sample in trace]
			negative = [0 if sample > 0 else sample for sample in trace]
			split_traces.append((positive, negative))
		else:
			split_traces.append(None)

	pos_areas, pos_rates, neg_areas, neg_rates = [], [], [], []
	for trace_num, parts in enumerate(split_traces):
		if parts is None:
			for column_values in (pos_rates, pos_areas, neg_rates, neg_areas):
				column_values.append('NaN')
			continue
		positive, negative = parts
		x = times[trace_num]
		# Integrate first, then take the span, so a bad time vector is
		# reported by metrics.auc before x[-1] / x[0] are indexed.
		pos_area = metrics.auc(x, positive)
		pos_rates.append(float(pos_area / (x[-1] - x[0])))
		pos_areas.append(float(pos_area))
		neg_area = metrics.auc(x, negative)
		neg_rates.append(float(neg_area / (x[-1] - x[0])))
		neg_areas.append(float(neg_area))

	event_dict[pos_rate_col].extend(pos_rates)
	event_dict[pos_col].extend(pos_areas)
	event_dict[neg_rate_col].extend(neg_rates)
	event_dict[neg_col].extend(neg_areas)


#Count transients. For every baseline/event window (a sequence of timestamps
#taken from trace_dict), this finds the peaks whose peak_dict['Peak Time'] lies
#between the window's first and last timestamp, inclusive. It records three
#values per window:
#  - the number of such peaks;
#  - that number divided by the window span (the column name says per second;
#    units follow the timestamps);
#  - the nanmean of those peaks' peak_dict['Peak Z'] values, or 'NaN' if there
#    are none.
#Windows with fewer than two timestamps get the 'NaN' sentinel in all three
#columns, so every column keeps exactly one entry per window.
time_names = ['Baseline Time', 'Event Time']

for name in time_names:
	temp_counts = []
	temp_counts_norm = []
	temp_amp_mean = []
	for trace in trace_dict[name]:
		if len(trace) > 1:
			start_ts = trace[0]
			end_ts = trace[-1]
			# Amplitudes of only this window's peaks; built fresh per window
			# so no values (or sentinels) carry over from another window.
			temp_amp = [
				peak_dict['Peak Z'][peak_num]
				for peak_num, peak_ts in enumerate(peak_dict['Peak Time'])
				if start_ts <= peak_ts <= end_ts
			]
			peak_counter = len(temp_amp)
			temp_counts.append(peak_counter)
			temp_counts_norm.append(peak_counter/(end_ts - start_ts))
			if temp_amp:
				temp_amp_mean.append(float(np.nanmean(temp_amp)))
			else:
				temp_amp_mean.append('NaN')
		else:
			# The sentinel must go into the mean-amplitude column (not the
			# scratch amplitude list) to keep all three columns aligned.
			temp_counts.append('NaN')
			temp_counts_norm.append('NaN')
			temp_amp_mean.append('NaN')
	if name == 'Baseline Time':
		event_dict['Event Baseline Transients'].extend(temp_counts)
		event_dict['Event Baseline Transients/s'].extend(temp_counts_norm)
		event_dict['Event Baseline Transients Mean Amplitude'].extend(temp_amp_mean)
	elif name == 'Event Time':
		event_dict['Event Transients'].extend(temp_counts)
		event_dict['Event Transients/s'].extend(temp_counts_norm)
		event_dict['Event Transients Mean Amplitude'].extend(temp_amp_mean)

#Averaging for event types: group quantified events by behaviour.
#An event is considered only if its 'Event Baseline Z-Score AUC (+)' value is
#neither the 'NaN' sentinel nor an empty string. Its name is then checked,
#case-insensitively, for each keyword below. Every match appends a fresh list
#of the event's values for all columns after 'Name'. Matching is by substring
#and is not exclusive, so one event can land in several category lists.
struggle_list = []
non_bout_list = []
random_list = []

rearing_list = []
grooming_list = []
onwheel_list = []
running_list = []

keyword_targets = (
	('struggle', struggle_list),
	('non', non_bout_list),
	('random', random_list),
	('rearing', rearing_list),
	('grooming', grooming_list),
	('onwheel', onwheel_list),
	('running', running_list),
)
value_columns = event_names[1:]

for event_num, event in enumerate(event_dict['Name']):
	baseline_auc = event_dict['Event Baseline Z-Score AUC (+)'][event_num]
	if baseline_auc == 'NaN' or baseline_auc == '':
		continue
	for keyword, target in keyword_targets:
		if keyword in event.casefold():
			target.append([event_dict[column][event_num] for column in value_columns])

event_df = pd.DataFrame.from_dict(event_dict)


def _category_average(rows):
	"""Return the column-wise mean of *rows*, skipping string entries.

	rows: non-empty list of equal-length lists, one per event, holding the
	values of every column after 'Name'.
	Returns one entry per column: statistics.mean of that column's non-string
	values, or the 'NaN' sentinel when the column has no non-string values.
	"""
	col_mean = []
	for column in zip(*rows):
		# String entries (e.g. the 'NaN' sentinel) are excluded from the mean.
		values = [item for item in column if not isinstance(item, str)]
		col_mean.append(statistics.mean(values) if values else 'NaN')
	return col_mean


#Append one summary row per non-empty behaviour category, in this fixed order.
#The label becomes the row's first cell, i.e. its 'Name' value.
for label, rows in (
	('Struggle Average', struggle_list),
	('Non-Bout Average', non_bout_list),
	('Random Sample Average', random_list),
	('Rearing Average', rearing_list),
	('Grooming Average', grooming_list),
	('On Wheel Average', onwheel_list),
	('Running Average', running_list),
):
	if rows:
		event_df.loc[len(event_df)] = [label] + _category_average(rows)

###Save calc dicts as .csv
#The output path is output_handle without its last four characters
#(presumably a '.csv'-style extension -- confirm against the caller) plus a
#fixed suffix. The DataFrame index is not written.
calcs_path = output_handle[:-4] + '_4s_peri-event_calcs.csv'
event_df.to_csv(calcs_path, index=False)
